#include <bits/stdc++.h>
#define rep(i, x, y) for (int i = x; i <= y; i++)
using namespace std;
// ---- Limits and global state shared by the segment-tree routines ----------
constexpr int N = 2e5 + 10;  // capacity for per-tree-node arrays (indices 1..n used)
using ll = long long;

int n;    // number of tree nodes
int m;    // number of operations; segment trees span time positions [0, m]
int idx;  // id of the last node taken from the `tr` pool (id 0 is the null node)

int col[N];          // current 0/1 flag of each tree node (toggled by non-1 ops)
int lst[N];          // NOTE(review): not referenced anywhere in the visible code
int rt[N];           // root id of each tree node's dynamic segment tree (0 = empty)
ll val[N];           // per-node answer, accumulated in main and while merging
vector<int> nxt[N];  // adjacency lists of the (undirected) tree

// One dynamic-segment-tree node over time positions.
struct node {
    int cnt;  // number of op-1 events recorded in this range
    int ls;   // left child id (0 = none)
    int rs;   // right child id (0 = none)
    ll sum;   // net weight change (+x / -x / initial i) recorded in this range
};

// Node pool; entry 0 is the shared all-zero null node.
// NOTE(review): 60 slots per N is a sizing guess, not derived from constraints.
node tr[N * 60];
void upd(int x) { int ls = tr[x].ls, rs = tr[x].rs; tr[x].cnt = tr[ls].cnt + tr[rs].cnt; tr[x].sum = tr[ls].sum + tr[rs].sum; }
// Point update on the dynamic segment tree rooted at `x` covering [l, r]:
// adds v1 to the event count and v2 to the weight sum of the leaf for `pos`.
// Missing nodes are taken from the `tr` pool as the walk reaches them, and `x`
// receives a fresh id if the tree was empty. Ancestors are refreshed bottom-up.
void modify(int &x, int l, int r, int pos, int v1, int v2) {
    int path[64];    // internal nodes on the way down, root first (<= 32 for int ranges)
    int depth = 0;
    int *slot = &x;  // location holding the current node's id (root ref or parent's child)

    for (;;) {
        if (*slot == 0) *slot = ++idx;
        const int cur = *slot;
        if (l == r) {
            tr[cur].cnt += v1;
            tr[cur].sum += v2;
            break;
        }
        path[depth++] = cur;
        const int mid = (l + r) >> 1;
        if (pos <= mid) {
            slot = &tr[cur].ls;
            r = mid;
        } else {
            slot = &tr[cur].rs;
            l = mid + 1;
        }
    }

    // Deepest ancestor first: the same order recursion unwinding would use.
    while (depth > 0) upd(path[--depth]);
}
void merge(int &x, int y, int l, int r, int id) { if (!x || !y) { x = (x | y); return; } if (l == r) { tr[x].cnt += tr[y].cnt, tr[x].sum += tr[y].sum; return; } val[id] += 1ll * tr[tr[x].rs].cnt * tr[tr[y].ls].sum; val[id] += 1ll * tr[tr[x].ls].sum * tr[tr[y].rs].cnt; int mid = (l + r) >> 1; merge(tr[x].ls, tr[y].ls, l, mid, id); merge(tr[x].rs, tr[y].rs, mid + 1, r, id); upd(x); }
// Post-order traversal of the tree from `x` (arrived at from `fa`). Once a child
// subtree is finished, its segment tree is merged into its parent's, so every
// rt[v] ends up covering v's whole subtree and val[v] collects the cross pairs.
//
// Implemented with an explicit stack instead of recursion: the recursive form
// nests one call per tree level, so a path-shaped tree with ~2e5 nodes can
// overflow the default call stack. Children are visited, and merged, in the
// same adjacency order as the recursive version.
void dfs(int x, int fa) {
    struct Frame {
        int v;        // tree node being visited
        int parent;   // neighbour we arrived from (skipped while scanning)
        size_t next;  // index of the next neighbour of v to examine
    };
    vector<Frame> frames;
    frames.push_back({x, fa, 0});

    while (!frames.empty()) {
        Frame &top = frames.back();
        if (top.next < nxt[top.v].size()) {
            const int y = nxt[top.v][top.next++];
            if (y != top.parent) {
                const int v = top.v;
                frames.push_back({y, v, 0});  // may reallocate; `top` is not used afterwards
            }
            continue;
        }

        // All neighbours of `child` are done, so rt[child] is complete.
        const int child = top.v;
        frames.pop_back();
        if (!frames.empty()) {
            const int parent = frames.back().v;
            merge(rt[parent], rt[child], 0, m, parent);
        }
        // The starting node `x` is never merged into `fa`, matching the recursive form.
    }
}
int main() { cin >> n >> m; rep(i, 1, n) { scanf("%d", &col[i]); if (col[i]) modify(rt[i], 0, m, 0, 0, i); } rep(i, 1, n - 1) { int x, y; scanf("%d%d", &x, &y); nxt[x].push_back(y), nxt[y].push_back(x); } rep(i, 1, m) { int op, x; scanf("%d%d", &op, &x); if (op == 1) { modify(rt[x], 0, m, i, 1, 0); if (col[x]) val[x] += x; } else { col[x] ^= 1; if (col[x]) { modify(rt[x], 0, m, i, 0, x); } else { modify(rt[x], 0, m, i, 0, -x); } } } dfs(1, 0); rep(i, 1, n) printf("%lld\n", val[i]); return 0; }